import gym
import highway_env
from stable_baselines3 import DDPG
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
import numpy as np
from stable_baselines3.common.monitor import Monitor
import datetime


# Build the 'roundabout' environment and wrap it in Monitor, which is handed
# './logsDDPG' as its log location.
env = Monitor(gym.make('roundabout'), './logsDDPG')

# Exploration noise is currently disabled. To enable it, uncomment the lines
# below and pass ``action_noise=action_noise`` to the DDPG constructor.
# n_actions = env.action_space.shape[-1]
# action_noise = OrnsteinUhlenbeckActionNoise(
#     mean=np.zeros(n_actions), sigma=float(0.5) * np.ones(n_actions))

# Create the DDPG agent with an MLP policy on the monitored environment.
model = DDPG("MlpPolicy", env, gamma=0.99, verbose=1)

# Train for 50,000 timesteps and report the wall-clock duration of learn().
start_time = datetime.datetime.now()
model.learn(total_timesteps=50_000)
end_time = datetime.datetime.now()
total_time = end_time - start_time
print(f"Total training time: {total_time}")

# Persist the trained model, then reload it from disk; the evaluation code
# further down uses this reloaded instance.
model_path = "roundabout_ddpg_5_19/model"
model.save(model_path)
model = DDPG.load(model_path)

# Evaluate the loaded policy for 10 episodes. At every step three reward
# components are read from info['rewards'], accumulated, and their running
# averages (over all steps so far, across episodes) are printed.
number_of_collisions = 0  # only referenced by the disabled crash-rate snippet
T = 1  # 1-based count of evaluation steps; divisor for the running averages
sum_lane = 0
sum_comfort = 0
sum_efficiency = 0

try:
    for episode in range(10):
        print(episode)
        # Explicit booleans instead of a None sentinel for the loop condition.
        done = truncated = False
        obs, info = env.reset()
        while not (done or truncated):
            action, _states = model.predict(obs)
            obs, reward, done, truncated, info = env.step(action)

            # Request a frame every step. The returned frame is not consumed
            # here; it is kept for the (disabled) video writer below.
            cur_frame = env.render(mode="rgb_array")
            # out.write(cur_frame)

            rewards = info['rewards']
            sum_lane += rewards['lane_centering_reward']
            sum_comfort += rewards['comfort']
            sum_efficiency += rewards['efficiency']
            print('average lane is', sum_lane / T)
            print('average comfort is', sum_comfort / T)
            print('average efficiency is', sum_efficiency / T)

            # Disabled collision-rate tracking:
            # if info.get('crashed'):
            #     number_of_collisions += 1
            # print('crashrate is '+str(float(number_of_collisions)/T)+' and T is'+str(T))
            T += 1
finally:
    # Always release the environment (renderer and Monitor log), even if
    # evaluation fails part-way through.
    env.close()


